"""
Phase 2: Demographics -> Automation
====================================
Tests whether demographic structure (Z factors, old-age dependency,
youth dependency) predicts capital intensity, labor productivity,
and labor share. Follows Acemoglu & Restrepo (2022) hypothesis that
aging drives automation / capital deepening.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directory two levels above this file (the parent of the folder that
# contains this script).
PROJECT_DIR = Path(__file__).resolve().parent.parent
# Directory that contains PROJECT_DIR.
ROOT_DIR = PROJECT_DIR.parent
# Put <ROOT_DIR>/multilateral/src at the front of the import search path.
# This must run before the `model` import below, which is presumably
# resolved from that directory (TODO confirm where model.py lives).
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory; main() reads automation_panel.csv from here.
DATA = PROJECT_DIR / "data" / "processed"
# Output directory for the markdown tables written by write_table().
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# NOTE: side effect at import time -- the output directory (and any missing
# parents) is created as soon as this module is loaded.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Three-letter country codes (38 entries) used by main() to split the panel
# into an "OECD" and a "Non-OECD" subsample via the `iso3` column.
# NOTE(review): presumably the OECD membership list at the time of writing --
# confirm against the intended vintage before reusing.
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
]

# Column names appended to the demographic regressors in every
# "+ controls" specification built in main().
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']


# ── Helpers ──────────────────────────────────────────────────────────

def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.1 and
    an empty string otherwise (including when no comparison holds, e.g. NaN).
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one estimate as a pair of table cells.

    Returns ``(coef_cell, se_cell)``: the coefficient with four decimals
    followed by its significance stars, and the standard error with four
    decimals wrapped in parentheses.
    """
    coef_cell = f"{val:.4f}" + stars(p)
    se_cell = f"({se:.4f})"
    return coef_cell, se_cell


def run_panel_gls(df, y_var, x_vars, label, feature_names=None):
    """Fit a PanelGLS model of ``y_var`` on ``x_vars`` and summarise the fit.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel holding ``y_var``, every name in ``x_vars``, ``iso3`` and
        ``year``.
    y_var : str
        Dependent-variable column.
    x_vars : sequence of str
        Regressor columns, in the order they are handed to the model.
    label : str
        Model label stored under ``'model'`` and used in console output.
    feature_names : sequence of str, optional
        Display names for the regressors (same length and order as
        ``x_vars``). Falls back to ``x_vars`` when omitted or empty.

    Returns
    -------
    dict or None
        ``None`` when a required column is absent or fewer than 50 complete
        rows remain. Otherwise a dict with ``model``, ``n_obs``,
        ``n_countries``, ``r_squared``, ``rho`` and, per regressor,
        ``<name>_coef``, ``<name>_se`` and ``<name>_p``.

    Raises
    ------
    ValueError
        If ``feature_names`` is given but its length differs from ``x_vars``.
    """
    x_vars = list(x_vars)
    names = list(feature_names) if feature_names else x_vars
    # A length mismatch would otherwise silently drop coefficients (too few
    # names) or fail with an IndexError only after the fit (too many).
    if len(names) != len(x_vars):
        raise ValueError(
            f"feature_names has {len(names)} entries but x_vars has {len(x_vars)}"
        )

    cols = [y_var] + x_vars + ['iso3', 'year']
    # Skip (rather than crash with a KeyError) when the panel lacks a column,
    # mirroring how an under-sized sample is handled below.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  {label}: missing columns {missing}, skipping")
        return None

    sub = df[cols].dropna()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[x_vars].values,
            sub['iso3'].values, sub['year'].values)

    result = {
        'model': label,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }

    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    # NOTE(review): assumes gls.beta / gls.se / gls.pvalues are ordered like
    # the columns of X (no leading intercept) -- confirm against PanelGLS.
    for name, coef, se, pval in zip(names, gls.beta, gls.se, gls.pvalues):
        result[f'{name}_coef'] = coef
        result[f'{name}_se'] = se
        result[f'{name}_p'] = pval
        print(f"    {name:30s} {coef:8.4f} ({se:.4f}) {stars(pval)}")

    return result


def write_table(results, filename, title):
    """Write regression results to ``OUT_TABLES / filename`` as a markdown table.

    One column per model, two rows per variable (coefficient with stars,
    then standard error), followed by N, R² and country-count rows and a
    short legend.

    Parameters
    ----------
    results : list of dict
        Result dicts as produced by ``run_panel_gls``: each carries
        ``model``, ``n_obs``, ``r_squared``, ``n_countries`` and
        ``<var>_coef`` / ``<var>_se`` / ``<var>_p`` entries.
    filename : str
        File name created inside ``OUT_TABLES``.
    title : str
        Text of the level-1 heading placed above the table.

    Returns
    -------
    None
        Nothing is written when ``results`` is empty or ``None``.
    """
    if not results:
        return

    suffix = '_coef'

    # Variables in first-seen order across models. Only the trailing
    # '_coef' is removed: str.replace would also strip an inner '_coef'
    # (e.g. 'a_coef_b'), after which the lookups below would miss the key.
    all_vars = []
    for r in results:
        for k in r:
            if k.endswith(suffix):
                vname = k[:-len(suffix)]
                if vname not in all_vars:
                    all_vars.append(vname)

    # Alignment row: left-aligned label column, right-aligned model columns.
    divider = "|:---|" + "|".join(["---:"] * len(results)) + "|"

    lines = [f"# {title}\n"]
    lines.append("| Variable | " + " | ".join(r['model'] for r in results) + " |")
    lines.append(divider)

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                # Variable not in this model: leave both cells empty.
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Same divider again, separating coefficients from the summary rows.
    lines.append(divider)
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.append(n_row)
    lines.append(r2_row)
    lines.append(nc_row)

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # Explicit UTF-8: the table contains non-ASCII text ('R²'), and the
    # default encoding of write_text depends on the platform locale.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── Main ─────────────────────────────────────────────────────────────

# Regressor groups shared by the specifications built in main().
_Z_FACTORS = ['Z_1', 'Z_2', 'Z_3']
_Z_INTERACTIONS = ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
_AGE_BOTH = ['old_dep', 'youth_dep']
_AGE_OLD = ['old_dep']


def _banner(text, width=50, leading_blank=True):
    """Print ``text`` framed by two rules of ``width`` '=' characters."""
    rule = "=" * width
    print(("\n" if leading_blank else "") + rule)
    print(text)
    print(rule)


def _run_specs(specs):
    """Run each (frame, banner, y_var, x_vars, label) spec; keep fits that succeed."""
    fits = []
    for frame, banner, y_var, x_vars, label in specs:
        print(f"\n--- {banner} ---")
        fit = run_panel_gls(frame, y_var, x_vars, label)
        if fit:
            fits.append(fit)
    return fits


def _run_z_ladder(df, y_var, interactions):
    """Run the M1/M2/M3 ladder for ``y_var``; M3 only if interaction columns exist."""
    specs = [
        (df, "M1: Z only", y_var, _Z_FACTORS, 'M1: Z only'),
        (df, "M2: Z + controls", y_var, _Z_FACTORS + CONTROLS, 'M2: Z + controls'),
    ]
    if interactions:
        specs.append((df, "M3: Z + controls + Z x KAOPEN", y_var,
                      _Z_FACTORS + CONTROLS + interactions, 'M3: Z x KAOPEN'))
    fits = _run_specs(specs)
    if not interactions:
        # Without interaction columns M3 would use exactly M2's regressors
        # yet be published under the "Z x KAOPEN" label, so it is skipped.
        print("\n--- M3: Z + controls + Z x KAOPEN ---")
        print("  M3: Z x KAOPEN: no Z x KAOPEN interaction columns in panel, skipping")
    return fits


def main():
    """Run the Phase 2 regressions and write one markdown table per family.

    Reads ``automation_panel.csv`` from ``DATA`` and produces:

    1. Z factors -> capital intensity (M1/M2/M3 ladder)
    2. Z factors -> log labor productivity (M1/M2/M3 ladder)
    3. Z factors -> labor share, only if ``labsh`` has more than 100
       non-missing values
    4. OECD vs. non-OECD subsamples (Z + controls)
    5. Age decomposition (old-age / youth dependency + controls), with the
       ``automation_proxy`` model added only if that column has more than
       100 non-missing values

    Models that ``run_panel_gls`` skips (returns ``None``) are left out of
    the tables; ``write_table`` writes nothing for an empty list.
    """
    _banner("PHASE 2: DEMOGRAPHICS -> AUTOMATION", width=70, leading_blank=False)

    df = pd.read_csv(DATA / "automation_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # Report available DVs
    dvs = {
        'capital_intensity': 'Investment/GDP',
        'labor_productivity': 'GDP per capita PPP',
        'log_labor_productivity': 'Log GDP per capita PPP',
        'labsh': 'Labor share (PWT)',
        'automation_proxy': 'Capital share (1 - labsh)',
        'capital_per_worker': 'Capital per worker (PWT)',
    }
    print("\n  Available dependent variables:")
    for v, desc in dvs.items():
        if v in df.columns:
            n = df[v].dropna().shape[0]
            print(f"    {v:30s} ({desc}): {n} obs")

    # Interaction terms are optional: use whichever ones the panel carries.
    interactions = [v for v in _Z_INTERACTIONS if v in df.columns]

    # ── Table 1: Z -> Capital Intensity ──
    _banner("TABLE 1: Z -> CAPITAL INTENSITY (Investment/GDP)")
    write_table(_run_z_ladder(df, 'capital_intensity', interactions),
                "demo_capital_intensity.md",
                "Demographics -> Capital Intensity (Investment/GDP)")

    # ── Table 2: Z -> Labor Productivity ──
    _banner("TABLE 2: Z -> LABOR PRODUCTIVITY (Log GDP per capita)")
    write_table(_run_z_ladder(df, 'log_labor_productivity', interactions),
                "demo_labor_productivity.md",
                "Demographics -> Labor Productivity (Log GDP per capita PPP)")

    # ── Table 3: Z -> Labor Share (if PWT available) ──
    if 'labsh' in df.columns and df['labsh'].dropna().shape[0] > 100:
        _banner("TABLE 3: Z -> LABOR SHARE (PWT)")
        write_table(_run_z_ladder(df, 'labsh', interactions),
                    "demo_labsh.md",
                    "Demographics -> Labor Share (PWT labsh)")
    else:
        print("\n  Skipping labor share table (PWT labsh not available)")

    # ── Table 4: OECD vs. Non-OECD Subsample ──
    _banner("TABLE 4: OECD VS. NON-OECD SUBSAMPLE")

    in_oecd = df['iso3'].isin(OECD)
    df_oecd = df[in_oecd].copy()
    df_nooecd = df[~in_oecd].copy()
    print(f"  OECD: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")
    print(f"  Non-OECD: {len(df_nooecd)} obs, {df_nooecd['iso3'].nunique()} countries")

    z_controls = _Z_FACTORS + CONTROLS
    results_oecd = _run_specs([
        (df_oecd, "OECD: Z -> capital_intensity",
         'capital_intensity', z_controls, 'OECD: K-intensity'),
        (df_nooecd, "Non-OECD: Z -> capital_intensity",
         'capital_intensity', z_controls, 'Non-OECD: K-intensity'),
        (df_oecd, "OECD: Z -> log_labor_productivity",
         'log_labor_productivity', z_controls, 'OECD: Productivity'),
        (df_nooecd, "Non-OECD: Z -> log_labor_productivity",
         'log_labor_productivity', z_controls, 'Non-OECD: Productivity'),
    ])
    write_table(results_oecd, "demo_automation_oecd.md",
                "Demographics -> Automation: OECD vs. Non-OECD")

    # ── Table 5: Age Decomposition ──
    _banner("TABLE 5: AGE DECOMPOSITION (old_dep vs. youth_dep)")

    age_specs = [
        (df, "old_dep + youth_dep -> capital_intensity",
         'capital_intensity', _AGE_BOTH + CONTROLS, 'K-intensity: Age'),
        (df, "old_dep -> capital_intensity",
         'capital_intensity', _AGE_OLD + CONTROLS, 'K-intensity: Old'),
        (df, "old_dep + youth_dep -> log_labor_productivity",
         'log_labor_productivity', _AGE_BOTH + CONTROLS, 'Productivity: Age'),
        (df, "old_dep -> log_labor_productivity",
         'log_labor_productivity', _AGE_OLD + CONTROLS, 'Productivity: Old'),
    ]
    # Automation proxy (if available)
    if 'automation_proxy' in df.columns and df['automation_proxy'].dropna().shape[0] > 100:
        age_specs.append(
            (df, "old_dep + youth_dep -> automation_proxy",
             'automation_proxy', _AGE_BOTH + CONTROLS, 'Capital share: Age'))

    write_table(_run_specs(age_specs), "demo_automation_age.md",
                "Demographics -> Automation: Age Decomposition")

    _banner("PHASE 2 COMPLETE", width=70)

if __name__ == '__main__':
    main()
